[PyTorch] Enable head dim 256 for FA4 #2932
Conversation
Force-pushed from bdcc02e to 3b3f7d0
Greptile Summary

This PR enables head_dim=256 support for FlashAttention 4 by delegating head-dimension validation to FA4's own validation function (`v4_validate_head_dims`) instead of a static head-dim table.
Confidence Score: 4/5

Safe to merge on SM100 hardware; the static head-dim table replacement and dedicated hd256 test are well-structured, but the single FA4 import statement still hard-fails for any FA4 build that does not export the validation helper. The core logic change, delegating head-dim validation to FA4's own function, is correct, and the SM100-gated test properly signals intent. The remaining concern is in transformer_engine/pytorch/attention/dot_product_attention/backends.py: the grouped FA4 import at lines 167–171 is the single point where a missing export causes a hard failure.

Important Files Changed
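To make the import concern concrete, here is a minimal sketch of a per-symbol import that degrades to None instead of hard-failing. The module path `flash_attn.cute.interface` and the helper name `validate_head_dims` are assumptions for illustration, not the actual import in backends.py.

```python
# Hypothetical guard (module path and symbol name are assumptions): resolve
# FA4's head-dim validator if the installed build exports it, otherwise fall
# back to None so backend selection can skip the FA4-specific check.
try:
    from flash_attn.cute.interface import validate_head_dims as v4_validate_head_dims
except ImportError:
    v4_validate_head_dims = None
```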
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[get_attention_backend called] --> B{use_flash_attention_4\nAND v4_is_installed\nAND v4_validate_head_dims ≠ None?}
B -- No --> Z[Skip FA4 head-dim validation]
B -- Yes --> C[Compute _fa4_alignment\n= 16 // element_size]
C --> D[Call v4_validate_head_dims\nhead_dim_qk, head_dim_v,\nsm_major, alignment]
D -- AssertionError --> E[log: unsupported head dims\nuse_flash_attention_4 = False]
D -- OK --> F{SM100 AND\nhd_qk == hd_v == 256\nAND max_seqlen_q != max_seqlen_kv?}
F -- Yes --> G[log: hd256 cross-attn fallback\nuse_flash_attention_4 = False]
F -- No --> H{is_training AND\nhd_qk != hd_v AND\nhd_qk >= 128 AND SM100?}
H -- Yes --> I{dV TMEM misalignment?\ntile_hdimv//2 % dk_reduce_ncol != 0}
I -- Yes --> J[log: MLA dV bug\nuse_flash_attention_4 = False]
I -- No --> K[FA4 enabled]
H -- No --> K
E --> L[Fall back to other backend]
G --> L
J --> L
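The same decision flow, condensed into a Python sketch. The argument order of `v4_validate_head_dims` mirrors the chart but is an assumption, and `tile_hdimv` / `dk_reduce_ncol` stand in for kernel-internal values from FA4's SM100 backward pass rather than anything computed here.

```python
def fa4_head_dim_ok(
    head_dim_qk, head_dim_v, element_size, sm_major,
    max_seqlen_q, max_seqlen_kv, is_training,
    v4_validate_head_dims, tile_hdimv, dk_reduce_ncol,
):
    """Sketch of the FA4 gating above; names and call shapes are assumptions."""
    if v4_validate_head_dims is None:
        return True  # validator unavailable: skip FA4-specific head-dim checks

    alignment = 16 // element_size  # e.g. 8 for fp16/bf16
    try:
        # assumed call shape, mirroring the flowchart
        v4_validate_head_dims(head_dim_qk, head_dim_v, sm_major, alignment)
    except AssertionError:
        return False  # unsupported head dims: fall back to another backend

    is_sm100 = sm_major == 10
    # hd256 cross-attention fallback on SM100
    if is_sm100 and head_dim_qk == head_dim_v == 256 and max_seqlen_q != max_seqlen_kv:
        return False

    # MLA backward dV TMEM misalignment (manually filtered FA4 bug)
    if is_training and head_dim_qk != head_dim_v and head_dim_qk >= 128 and is_sm100:
        if (tile_hdimv // 2) % dk_reduce_ncol != 0:
            return False

    return True
```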
Reviews (6): Last reviewed commit: "Merge branch 'main' into xiny/headdim256..."
Signed-off-by: Xin Yao <xiny@nvidia.com>
Signed-off-by: Xin Yao <xiny@nvidia.com>
/te-ci pytorch L3
@vcherepanov-nv @KshitijLakhani Please review.
# dV TMEM load atoms. When (tile_hdimv // 2) % dK_reduce_ncol != 0, dV reads are
# misaligned. The dedicated (256, 256) kernel uses its own tmem layout so it's
# not affected. See: flash_attn/cute/flash_bwd_sm100.py, line ~262 and ~3890.
if (
Should this still be checked when FlashAttentionUtils.v4_validate_head_dims == None?
I double-checked that this is a bug in FA4. Kernels produce wrong results on these shapes, but they're allowed by v4_validate_head_dims, so we have to filter them out manually.
Raised an issue to FA4: Dao-AILab/flash-attention#2552
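For reference, the manual filter discussed above reduces to the predicate below. `tile_hdimv` and `dk_reduce_ncol` are kernel-internal values from flash_attn/cute/flash_bwd_sm100.py, so any concrete values fed in would be assumptions.

```python
def dv_tmem_misaligned(tile_hdimv: int, dk_reduce_ncol: int) -> bool:
    # True when the dV TMEM load atoms don't divide evenly, i.e. the shapes
    # that v4_validate_head_dims still accepts but that produce wrong dV.
    return (tile_hdimv // 2) % dk_reduce_ncol != 0
```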
LGTM |
Signed-off-by: Xin Yao <xiny@nvidia.com>
/te-ci pytorch L3
Signed-off-by: Xin Yao <xiny@nvidia.com>
/te-ci pytorch L3 |
Description
Need FA4 version 4.0.0b11.
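As a quick sanity check for this requirement, something like the following could verify the installed build; the distribution name "flash-attn" is an assumption about how FA4 is packaged.

```python
# Hedged sketch: confirm the installed FlashAttention meets the 4.0.0b11 minimum.
from importlib.metadata import version, PackageNotFoundError
from packaging.version import Version

try:
    fa4_ok = Version(version("flash-attn")) >= Version("4.0.0b11")
except PackageNotFoundError:
    fa4_ok = False
print("FA4 >= 4.0.0b11:", fa4_ok)
```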
Type of change
Changes

Please list the changes introduced in this PR:

- Delegate FA4 head-dim validation to FA4's own validation function instead of the static head-dim table.
- Manually filter out head-dim combinations that hit the FA4 SM100 backward dV TMEM misalignment bug (Dao-AILab/flash-attention#2552).
- Add a dedicated SM100-gated test for head_dim=256.
Checklist: